from mlagents.torch_utils import torch
from typing import List
import math

from mlagents.trainers.torch_entities.layers import (
    linear_layer,
    Swish,
    Initialization,
    LayerNorm,
)


class ConditionalEncoder(torch.nn.Module):
    def __init__(
        self,
        input_size: int,
        goal_size: int,
        hidden_size: int,
        num_layers: int,
        num_conditional_layers: int,
        kernel_init: Initialization = Initialization.KaimingHeNormal,
        kernel_gain: float = 1.0,
    ):
        """
        Fully connected encoder whose trailing layers are conditioned on a goal.
        The first (num_layers - num_conditional_layers) layers are ordinary linear
        layers with regular parameters. The last num_conditional_layers layers are
        HyperNetwork layers whose weight matrices are generated from the goal tensor.
        Every layer, regular or conditional, is followed by a Swish activation.
        If num_conditional_layers >= num_layers, every layer is conditional.
        :param input_size: The size of the input of the encoder
        :param goal_size: The size of the goal tensor that will condition the encoder
        :param hidden_size: The number of hidden units in every layer of the encoder
        :param num_layers: The total number of layers of the encoder (both regular and
        generated by HyperNetwork)
        :param num_conditional_layers: How many of the trailing layers are generated
        with hypernetworks
        :param kernel_init: The Initialization used for the weights of the regular
        (non-conditional) layers
        :param kernel_gain: The multiplier for the weights of the regular layers'
        kernels.
        """
        super().__init__()
        # Index of the first layer whose weights come from a hypernetwork; every
        # layer at or after this index is conditional.
        first_conditional = num_layers - num_conditional_layers
        modules: List[torch.nn.Module] = []
        in_features = input_size
        for index in range(num_layers):
            if index >= first_conditional:
                # Conditional layer: weights generated from the goal by a
                # 2-layer hypernetwork of width hidden_size.
                dense: torch.nn.Module = HyperNetwork(
                    in_features, hidden_size, goal_size, hidden_size, 2
                )
            else:
                dense = linear_layer(
                    in_features,
                    hidden_size,
                    kernel_init=kernel_init,
                    kernel_gain=kernel_gain,
                )
            # Layer followed by its activation, kept flat in a single ModuleList.
            modules.extend((dense, Swish()))
            in_features = hidden_size
        self.layers = torch.nn.ModuleList(modules)

    def forward(
        self, input_tensor: torch.Tensor, goal_tensor: torch.Tensor
    ) -> torch.Tensor:  # type: ignore
        """
        Run the encoder on input_tensor, conditioning the hypernetwork layers on
        goal_tensor.
        :param input_tensor: The encoder input; presumably (batch, input_size) since
        HyperNetwork layers use a batched matmul — confirm against callers.
        :param goal_tensor: The conditioning goal; presumably (batch, goal_size).
        :return: The activation after the last layer (hidden_size features), or
        input_tensor unchanged when the encoder has no layers.
        """
        hidden = input_tensor
        for module in self.layers:
            # Only hypernetwork layers consume the goal; Linear/Swish take one input.
            hidden = (
                module(hidden, goal_tensor)
                if isinstance(module, HyperNetwork)
                else module(hidden)
            )
        return hidden


class HyperNetwork(torch.nn.Module):
    def __init__(
        self,
        input_size: int,
        output_size: int,
        hyper_input_size: int,
        layer_size: int,
        num_layers: int,
    ):
        """
        Hyper Network module. This module uses the hyper_input tensor to generate
        the weights of the main network. The main network is a single fully connected
        layer: its (input_size x output_size) weight matrix is produced per sample by
        the hypernetwork, while its bias is a regular (non-generated) parameter.
        :param input_size: The size of the input of the main network
        :param output_size: The size of the output of the main network
        :param hyper_input_size: The size of the input of the hypernetwork that will
        generate the main network.
        :param layer_size: The number of hidden units in the layers of the hypernetwork
        :param num_layers: The number of hidden layers of the hypernetwork. With 0,
        the hyper_input goes through LayerNorm and then straight into the output
        projection.
        """
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size

        # Tracks the width of the features reaching the output projection:
        # hyper_input_size if there are no hidden layers, layer_size otherwise.
        layer_in_size = hyper_input_size
        layers: List[torch.nn.Module] = []
        for _ in range(num_layers):
            layers.append(
                linear_layer(
                    layer_in_size,
                    layer_size,
                    kernel_init=Initialization.KaimingHeNormal,
                    kernel_gain=1.0,
                    bias_init=Initialization.Zero,
                )
            )
            layers.append(Swish())
            layer_in_size = layer_size

        # Use the actual incoming width (not layer_size) so num_layers == 0 works.
        flat_output = linear_layer(
            layer_in_size,
            input_size * output_size,
            kernel_init=Initialization.KaimingHeNormal,
            kernel_gain=0.1,
            bias_init=Initialization.Zero,
        )

        # Re-initialize the weights of the last layer of the hypernetwork with a
        # uniform distribution whose bound shrinks with both the fan-in of this
        # projection and the input size of the generated layer. This overrides the
        # KaimingHeNormal init above.
        bound = math.sqrt(1 / (layer_in_size * self.input_size))
        flat_output.weight.data.uniform_(-bound, bound)

        self.hypernet = torch.nn.Sequential(*layers, LayerNorm(), flat_output)

        # The hypernetwork does not generate the bias of the main network layer.
        self.bias = torch.nn.Parameter(torch.zeros(output_size))

    def forward(
        self, input_activation: torch.Tensor, hyper_input: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply the generated fully connected layer to input_activation.
        :param input_activation: Main-network input of shape (batch, input_size);
        it must be 2-D because it is unsqueezed to 3-D for torch.bmm.
        :param hyper_input: Hypernetwork input whose leading dims flatten (via view)
        to the same batch size as input_activation; presumably
        (batch, hyper_input_size).
        :return: Tensor of shape (batch, output_size).
        """
        output_weights = self.hypernet(hyper_input)

        # One (input_size, output_size) weight matrix per sample.
        output_weights = output_weights.view(-1, self.input_size, self.output_size)

        # (batch, 1, input_size) @ (batch, input_size, output_size) -> (batch, output_size)
        result = (
            torch.bmm(input_activation.unsqueeze(1), output_weights).squeeze(1)
            + self.bias
        )
        return result
